Gradient reversal in PyTorch

Inspired by the following tweets:

Basic idea:

# Add something to gradient
f(x) + g(x) - tf.stop_gradient(g(x))

# Reverse gradient
tf.stop_gradient(f(x)*2) - f(x)
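
The PyTorch equivalents, used later in this notebook, replace tf.stop_gradient with .detach():

# Add something to gradient (PyTorch)
f(x) + g(x) - g(x).detach()

# Reverse gradient (PyTorch)
(f(x)*2).detach() - f(x)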

In [1]:
import torch
import tensorflow as tf
from torch.autograd import Variable

import numpy as np

In [2]:
def f(X):
    return X*X

def g(X):
    return X**3

In [3]:
X = np.random.randn(10)
X


Out[3]:
array([ 0.01995021, -0.32892969,  0.75804777,  0.172995  ,  0.69747771,
        1.11414039, -0.69194092,  2.43364877,  0.92732815, -0.91409348])

TensorFlow implementation


In [4]:
sess = tf.InteractiveSession()

In [5]:
tf_X = tf.Variable(X)
init_op = tf.global_variables_initializer()

In [6]:
sess.run(init_op)
sess.run(tf_X)


Out[6]:
array([ 0.01995021, -0.32892969,  0.75804777,  0.172995  ,  0.69747771,
        1.11414039, -0.69194092,  2.43364877,  0.92732815, -0.91409348])

In [7]:
forward_op = f(tf_X)

In [8]:
sess.run(forward_op)


Out[8]:
array([  3.98010770e-04,   1.08194738e-01,   5.74636421e-01,
         2.99272683e-02,   4.86475162e-01,   1.24130881e+00,
         4.78782241e-01,   5.92264633e+00,   8.59937506e-01,
         8.35566890e-01])

In [9]:
gradient_op = tf.gradients(forward_op, tf_X)

In [10]:
sess.run(gradient_op)


Out[10]:
[array([ 0.03990041, -0.65785937,  1.51609554,  0.34598999,  1.39495543,
         2.22828078, -1.38388185,  4.86729754,  1.85465631, -1.82818696])]

In [11]:
X*2 # This should match the gradient above


Out[11]:
array([ 0.03990041, -0.65785937,  1.51609554,  0.34598999,  1.39495543,
        2.22828078, -1.38388185,  4.86729754,  1.85465631, -1.82818696])

Modify the gradients

Keep the forward pass the same. The trick is to add a term $g(x)$, whose derivative $g'(x)$ is the desired gradient modifier, during the forward pass and to subtract it again, while stopping gradients from flowing through the subtracted term.

$f(x) + g(x) - g(x)$ would ordinarily give gradients $f'(x) + g'(x) - g'(x)$. Since gradients are stopped from flowing through the subtracted $-g(x)$ term, the resulting gradient is $f'(x) + g'(x)$.
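
With the concrete choices used here, $f(x) = x^2$ and $g(x) = x^3$, the forward pass should still return $x^2$ while the gradient becomes $2x + 3x^2$; the cells below verify both.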


In [12]:
gradient_modifier_op = g(tf_X)

In [13]:
sess.run(gradient_modifier_op)


Out[13]:
array([  7.94039737e-06,  -3.55884610e-02,   4.35601858e-01,
         5.17726764e-03,   3.39305584e-01,   1.38299228e+00,
        -3.31289026e-01,   1.44136410e+01,   7.97444260e-01,
        -7.63786246e-01])

In [14]:
modified_forward_op = (f(tf_X) + g(tf_X) - tf.stop_gradient(g(tf_X)))
modified_backward_op = tf.gradients(modified_forward_op, tf_X)

In [15]:
sess.run(modified_forward_op)


Out[15]:
array([  3.98010770e-04,   1.08194738e-01,   5.74636421e-01,
         2.99272683e-02,   4.86475162e-01,   1.24130881e+00,
         4.78782241e-01,   5.92264633e+00,   8.59937506e-01,
         8.35566890e-01])

In [16]:
sess.run(modified_backward_op)


Out[16]:
[array([  0.04109445,  -0.33327516,   3.2400048 ,   0.4357718 ,
          2.85438092,   5.95220721,   0.05246488,  22.63523654,
          4.43446883,   0.67851371])]

In [17]:
2*X + 3*(X**2) # This should match the gradients above


Out[17]:
array([  0.04109445,  -0.33327516,   3.2400048 ,   0.4357718 ,
         2.85438092,   5.95220721,   0.05246488,  22.63523654,
         4.43446883,   0.67851371])

Gradient reversal

Here the modifying function $g(x)$ is simply $-2f(x)$, which gives gradients $f'(x) - 2f'(x) = -f'(x)$ while the forward pass still evaluates to $f(x)$.
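
With $f(x) = x^2$, the reversed gradient is $-2x$; adding it to the unmodified gradient $2x$ computed earlier should therefore give exactly zero, which the last cell of this section checks.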


In [18]:
gradient_reversal_op = (tf.stop_gradient(2*f(tf_X)) - f(tf_X))
gradient_reversal_grad_op = tf.gradients(gradient_reversal_op, tf_X)

In [19]:
sess.run(gradient_reversal_op)


Out[19]:
array([  3.98010770e-04,   1.08194738e-01,   5.74636421e-01,
         2.99272683e-02,   4.86475162e-01,   1.24130881e+00,
         4.78782241e-01,   5.92264633e+00,   8.59937506e-01,
         8.35566890e-01])

In [20]:
sess.run(gradient_reversal_grad_op)


Out[20]:
[array([-0.03990041,  0.65785937, -1.51609554, -0.34598999, -1.39495543,
        -2.22828078,  1.38388185, -4.86729754, -1.85465631,  1.82818696])]

In [21]:
sess.run(gradient_op[0] + gradient_reversal_grad_op[0]) # This should be all zeros, confirming the gradient is reversed


Out[21]:
array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])
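
The reversal can also be wrapped in a small reusable helper. A minimal sketch (the name reverse_gradient is not from the original notebook), assuming the same TF 1.x graph-mode session:

In [ ]:
def reverse_gradient(x):
    # Forward pass is the identity (2*x - x == x); backward pass negates the gradient
    return tf.stop_gradient(2*x) - x

sess.run(tf.gradients(reverse_gradient(f(tf_X)), tf_X))  # should again be -2*X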

PyTorch implementation


In [22]:
def zero_grad(X):
    if X.grad is not None:
        X.grad.data.zero_()

In [23]:
torch_X = Variable(torch.FloatTensor(X), requires_grad=True)

In [24]:
torch_X.data.numpy()


Out[24]:
array([ 0.01995021, -0.32892969,  0.75804776,  0.172995  ,  0.6974777 ,
        1.11414039, -0.6919409 ,  2.43364882,  0.92732817, -0.91409349], dtype=float32)

In [25]:
f(torch_X).data.numpy()


Out[25]:
array([  3.98010772e-04,   1.08194746e-01,   5.74636400e-01,
         2.99272705e-02,   4.86475140e-01,   1.24130881e+00,
         4.78782207e-01,   5.92264652e+00,   8.59937549e-01,
         8.35566938e-01], dtype=float32)

In [26]:
g(torch_X).data.numpy()


Out[26]:
array([  7.94039715e-06,  -3.55884619e-02,   4.35601830e-01,
         5.17726783e-03,   3.39305550e-01,   1.38299227e+00,
        -3.31288993e-01,   1.44136410e+01,   7.97444284e-01,
        -7.63786316e-01], dtype=float32)

In [27]:
zero_grad(torch_X)
f_X = f(torch_X)
f_X.backward(torch.ones(f_X.size()))
torch_X.grad.data.numpy()


Out[27]:
array([ 0.03990041, -0.65785939,  1.51609552,  0.34599   ,  1.3949554 ,
        2.22828078, -1.38388181,  4.86729765,  1.85465634, -1.82818699], dtype=float32)

In [28]:
2*X


Out[28]:
array([ 0.03990041, -0.65785937,  1.51609554,  0.34598999,  1.39495543,
        2.22828078, -1.38388185,  4.86729754,  1.85465631, -1.82818696])

Modify gradients


In [29]:
modified_gradients_forward = lambda x: f(x) + g(x) - g(x).detach()

In [30]:
zero_grad(torch_X)
modified_grad = modified_gradients_forward(torch_X)
modified_grad.backward(torch.ones(modified_grad.size()))
torch_X.grad.data.numpy()


Out[30]:
array([  0.04109445,  -0.33327514,   3.24000454,   0.43577182,
         2.85438085,   5.95220757,   0.05246484,  22.63523865,
         4.43446875,   0.67851377], dtype=float32)

In [31]:
2*X + 3*(X*X) # Should be the same as the gradients above


Out[31]:
array([  0.04109445,  -0.33327516,   3.2400048 ,   0.4357718 ,
         2.85438092,   5.95220721,   0.05246488,  22.63523654,
         4.43446883,   0.67851371])

Gradient reversal


In [32]:
gradient_reversal = lambda x: (2*f(x)).detach() - f(x)

In [33]:
zero_grad(torch_X)
grad_reverse = gradient_reversal(torch_X)
grad_reverse.backward(torch.ones(grad_reverse.size()))
torch_X.grad.data.numpy()


Out[33]:
array([-0.03990041,  0.65785939, -1.51609552, -0.34599   , -1.3949554 ,
       -2.22828078,  1.38388181, -4.86729765, -1.85465634,  1.82818699], dtype=float32)

In [34]:
-2*X # Should be the same as the gradients above


Out[34]:
array([-0.03990041,  0.65785937, -1.51609554, -0.34598999, -1.39495543,
       -2.22828078,  1.38388185, -4.86729754, -1.85465631,  1.82818696])

PyTorch backward hooks


In [35]:
# Gradient reversal
zero_grad(torch_X)
f_X = f(torch_X)
f_X.register_hook(lambda grad: -grad)
f_X.backward(torch.ones(f_X.size()))
torch_X.grad.data.numpy()


Out[35]:
array([-0.03990041,  0.65785939, -1.51609552, -0.34599   , -1.3949554 ,
       -2.22828078,  1.38388181, -4.86729765, -1.85465634,  1.82818699], dtype=float32)

In [36]:
-2*X


Out[36]:
array([-0.03990041,  0.65785937, -1.51609554, -0.34598999, -1.39495543,
       -2.22828078,  1.38388185, -4.86729754, -1.85465631,  1.82818696])

In [37]:
# Modified grad example
zero_grad(torch_X)
h = torch_X.register_hook(lambda grad: grad + 3*(torch_X*torch_X))
f_X = f(torch_X)
f_X.backward(torch.ones(f_X.size()))
h.remove()
torch_X.grad.data.numpy()


Out[37]:
array([  0.04109445,  -0.33327514,   3.24000454,   0.43577182,
         2.85438085,   5.95220757,   0.05246484,  22.63523865,
         4.43446875,   0.67851377], dtype=float32)

In [38]:
2*X + 3*(X*X) # Should be the same as the gradients above


Out[38]:
array([  0.04109445,  -0.33327516,   3.2400048 ,   0.4357718 ,
         2.85438092,   5.95220721,   0.05246488,  22.63523654,
         4.43446883,   0.67851371])
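
In practice (e.g. domain-adversarial training) the reversal is usually packaged as a custom autograd Function so it can sit anywhere inside a model. A minimal sketch using the newer static-method Function API; GradReverse and the lambd scaling factor are illustrative names, not part of the original notebook:

In [ ]:
class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lambd=1.0):
        # Identity in the forward pass; store the scaling factor for backward
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Negate (and scale) the incoming gradient; lambd itself gets no gradient
        return -ctx.lambd * grad_output, None

zero_grad(torch_X)
reversed_f_X = GradReverse.apply(f(torch_X), 1.0)
reversed_f_X.backward(torch.ones(reversed_f_X.size()))
torch_X.grad.data.numpy()  # should again match -2*X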
